import numpy as np
import tensorflow as tf
from tensorflow.python.training import moving_averages

from src.common_ops import create_weight
from src.common_ops import create_bias


def drop_path(x, keep_prob):
    """Drops out a whole example hiddenstate with the specified probability."""

    batch_size = tf.shape(x)[0]
    noise_shape = [batch_size, 1, 1, 1]
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
    binary_tensor = tf.floor(random_tensor)
    x = tf.div(x, keep_prob) * binary_tensor

    return x


def conv(x, filter_size, out_filters, stride, name="conv", padding="SAME",
         data_format="NHWC", seed=None):
    """
    Args:
      stride: [h_stride, w_stride].
    """

    if data_format == "NHWC":
        actual_data_format = "channels_last"
    elif data_format == "NCHW":
        actual_data_format = "channels_first"
    else:
        raise NotImplementedError("Unknown data_format {}".format(data_format))
    x = tf.layers.conv2d(
        x, out_filters, [filter_size, filter_size], stride, padding,
        data_format=actual_data_format,
        kernel_initializer=tf.contrib.keras.initializers.he_normal(seed=seed))

    return x


def fully_connected(x, out_size, name="fc", seed=None):
    in_size = x.get_shape()[-1].value  
    with tf.variable_scope(name):
        w = create_weight("w", [in_size, out_size], seed=seed)
    x = tf.matmul(x, w)
    return x


def max_pool(x, k_size, stride, padding="SAME", data_format="NHWC",
             keep_size=False):
    """
    Args:
      k_size: two numbers [h_k_size, w_k_size].
      stride: two numbers [h_stride, w_stride].
    """

    if data_format == "NHWC":
        actual_data_format = "channels_last"
    elif data_format == "NCHW":
        actual_data_format = "channels_first"
    else:
        raise NotImplementedError("Unknown data_format {}".format(data_format))
    out = tf.layers.max_pooling2d(x, k_size, stride, padding,
                                  data_format=actual_data_format)

    if keep_size:
        if data_format == "NHWC":
            h_pad = (x.get_shape()[1].value - out.get_shape()[1].value) // 2
            w_pad = (x.get_shape()[2].value - out.get_shape()[2].value) // 2
            out = tf.pad(out, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]])
        elif data_format == "NCHW":
            h_pad = (x.get_shape()[2].value - out.get_shape()[2].value) // 2
            w_pad = (x.get_shape()[3].value - out.get_shape()[3].value) // 2
            out = tf.pad(out, [[0, 0], [0, 0], [h_pad, h_pad], [w_pad, w_pad]])
        else:
            raise NotImplementedError(
                "Unknown data_format {}".format(data_format))
    return out


def global_avg_pool(x, data_format="NHWC"):
    if data_format == "NHWC":
        x = tf.reduce_mean(x, [1, 2])
    elif data_format == "NCHW":
        x = tf.reduce_mean(x, [2, 3])
    else:
        raise NotImplementedError("Unknown data_format {}".format(data_format))
    return x


def batch_norm(x, is_training, name="bn", decay=0.9, epsilon=1e-5,
               data_format="NHWC"):
    if data_format == "NHWC":
        shape = [x.get_shape()[3]]
    elif data_format == "NCHW":
        shape = [x.get_shape()[1]]
    else:
        raise NotImplementedError("Unknown data_format {}".format(data_format))

    with tf.variable_scope(name, reuse=None if is_training else True):
        offset = tf.get_variable(
            "offset", shape,
            initializer=tf.constant_initializer(0.0, dtype=tf.float32))
        scale = tf.get_variable(
            "scale", shape,
            initializer=tf.constant_initializer(1.0, dtype=tf.float32))
        moving_mean = tf.get_variable(
            "moving_mean", shape, trainable=False,
            initializer=tf.constant_initializer(0.0, dtype=tf.float32))
        moving_variance = tf.get_variable(
            "moving_variance", shape, trainable=False,
            initializer=tf.constant_initializer(1.0, dtype=tf.float32))

        if is_training:
            x, mean, variance = tf.nn.fused_batch_norm(
                x, scale, offset, epsilon=epsilon, data_format=data_format,
                is_training=True)
            update_mean = moving_averages.assign_moving_average(
                moving_mean, mean, decay)
            update_variance = moving_averages.assign_moving_average(
                moving_variance, variance, decay)
            with tf.control_dependencies([update_mean, update_variance]):
                x = tf.identity(x)
        else:
            x, _, _ = tf.nn.fused_batch_norm(x, scale, offset, mean=moving_mean,
                                             variance=moving_variance,
                                             epsilon=epsilon, data_format=data_format,
                                             is_training=False)
    return x


def batch_norm_with_mask(x, is_training, mask, num_channels, name="bn",
                         decay=0.9, epsilon=1e-3, data_format="NHWC"):

    shape = [num_channels]
    indices = tf.where(mask)
    indices = tf.to_int32(indices)
    indices = tf.reshape(indices, [-1])

    with tf.variable_scope(name, reuse=None if is_training else True):
        offset = tf.get_variable(
            "offset", shape,
            initializer=tf.constant_initializer(0.0, dtype=tf.float32))
        scale = tf.get_variable(
            "scale", shape,
            initializer=tf.constant_initializer(1.0, dtype=tf.float32))
        offset = tf.boolean_mask(offset, mask)
        scale = tf.boolean_mask(scale, mask)

        moving_mean = tf.get_variable(
            "moving_mean", shape, trainable=False,
            initializer=tf.constant_initializer(0.0, dtype=tf.float32))
        moving_variance = tf.get_variable(
            "moving_variance", shape, trainable=False,
            initializer=tf.constant_initializer(1.0, dtype=tf.float32))

        if is_training:
            x, mean, variance = tf.nn.fused_batch_norm(
                x, scale, offset, epsilon=epsilon, data_format=data_format,
                is_training=True)
            mean = (1.0 - decay) * (tf.boolean_mask(moving_mean, mask) - mean)
            variance = (1.0 - decay) * \
                (tf.boolean_mask(moving_variance, mask) - variance)
            update_mean = tf.scatter_sub(
                moving_mean, indices, mean, use_locking=True)
            update_variance = tf.scatter_sub(
                moving_variance, indices, variance, use_locking=True)
            with tf.control_dependencies([update_mean, update_variance]):
                x = tf.identity(x)
        else:
            masked_moving_mean = tf.boolean_mask(moving_mean, mask)
            masked_moving_variance = tf.boolean_mask(moving_variance, mask)
            x, _, _ = tf.nn.fused_batch_norm(x, scale, offset,
                                             mean=masked_moving_mean,
                                             variance=masked_moving_variance,
                                             epsilon=epsilon, data_format=data_format,
                                             is_training=False)
    return x


def relu(x, leaky=0.0):
    return tf.where(tf.greater(x, 0), x, x * leaky)
